import torch
from torch import nn
from .Loss import *

class SpatialTransformer(nn.Module):
    """Warp an image with a dense displacement field using ``F.grid_sample``.

    ``flow`` is added to an identity grid of integer voxel indices, the result
    is normalised to ``[-1, 1]`` per axis and handed to ``grid_sample`` with
    ``align_corners=True``.
    """

    def __init__(self, size, mode='bilinear'):
        """
        Args:
            size: Spatial shape of the sampling grid, e.g. ``(H, W)`` or ``(D, H, W)``.
            mode: Interpolation mode forwarded to ``F.grid_sample``.
        """
        super().__init__()

        self.mode = mode

        # Identity grid of voxel coordinates, shape (1, len(size), *size).
        # 'ij' is the indexing meshgrid used implicitly before; passing it
        # explicitly silences PyTorch's "indexing argument will be required"
        # warning without changing the grid.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

    def forward(self, src, flow, return_phi=False):
        """Sample ``src`` at ``grid + flow``.

        Args:
            src: Image to warp, ``(N, C, *spatial)``.
            flow: Displacement in voxel units, ``(N, ndims, *spatial)``; channel
                ``i`` displaces along spatial axis ``i``.
            return_phi: If True, also return the normalised sampling grid.

        Returns:
            The warped image, or ``(warped, new_locs)`` when ``return_phi``.
        """
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Map voxel coordinates to [-1, 1]. A singleton axis has no extent:
        # clamp the denominator to 1 so it maps to a finite value instead of
        # 0/0 = NaN (with align_corners=True such an axis always samples index 0).
        for i, extent in enumerate(shape):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / max(extent - 1, 1) - 0.5)

        # grid_sample wants the coordinate channel last and ordered (x, y[, z]),
        # i.e. reversed with respect to the tensor's spatial axis order.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)[..., [2, 1, 0]]

        warped = F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
        return (warped, new_locs) if return_phi else warped

class URED(nn.Module):
    """One update step ``field -> field - gamma_k * (delta_g + tau * dnn(field))``.

    ``delta_g`` is the gradient of an image-matching loss with respect to the
    displacement field (see :meth:`get_grad`). ``gamma_k`` is ``gamma`` scaled
    by a cosine schedule over ``max_iter`` iterations.
    """

    def __init__(self, dnn: nn.Module, config: dict):
        """
        Args:
            dnn: Network applied to the field; its output is weighted by ``tau``.
            config: Dict with keys ``'gamma_inti'`` (step size), ``'tau_inti'``
                (weight of the ``dnn`` term), ``'image_shape'`` (spatial shape of
                the full-resolution images), ``'iteration'`` (length of the
                cosine schedule) and ``'weight_grad'`` (weight of the smoothness
                term in :meth:`get_grad`).
        """
        super().__init__()
        self.dnn = dnn
        # Plain tensor attributes (not registered buffers): they are not in
        # state_dict and are not moved by .to(device).
        self.gamma = torch.tensor(config['gamma_inti'], dtype=torch.float32)
        self.tau = torch.tensor(config['tau_inti'], dtype=torch.float32)
        self.inshape = config['image_shape']
        # The field is upsampled by 2 before warping on the full `inshape` grid
        # (presumably it is stored at half resolution). The rank follows
        # `inshape`, matching the SpatialTransformer built below.
        self.resize = ResizeTransform(1/2, len(self.inshape))
        self.max_iter = config['iteration']
        self.weight_grad = config['weight_grad']
        self.transformer = SpatialTransformer(self.inshape)

    def get_grad(self, field, in_image, fixed):
        """Gradient of the data-fidelity + smoothness loss w.r.t. ``field``.

        The loss is ``GCC()(warp(in_image, resize(field)), fixed)
        + weight_grad * Grad('l2', loss_mult=2).loss(field)``. ``GCC`` and
        ``Grad`` come from ``.Loss`` (presumably a global cross-correlation
        similarity and a spatial-gradient smoothness penalty — confirm there).
        Inputs are detached first, so the result carries no graph.

        Returns:
            Tensor shaped like ``field``.
        """
        image = in_image.clone().detach()
        leaf_field = field.clone().detach().requires_grad_(True)
        full_field = self.resize(leaf_field)
        image_pred = self.transformer(image, full_field, return_phi=False)

        loss_func = GCC()
        grad_func = Grad('l2', loss_mult=2).loss
        loss = loss_func(image_pred, fixed) + self.weight_grad * grad_func(leaf_field)
        # autograd.grad returns the gradient directly instead of accumulating
        # it into `.grad` of every leaf in the graph (e.g. `fixed` if it
        # happens to require grad).
        (soft_dc,) = torch.autograd.grad(loss, leaf_field)
        return soft_dc

    def forward(self, field, moving, fixed, interation, flag):
        """Apply one update step to ``field``.

        Args:
            field: Current displacement-field estimate.
            moving: Image warped by the (resized) field inside ``get_grad``.
            fixed: Target image for the similarity loss.
            interation: Iteration index ``k`` for the step-size schedule
                ``gamma_k = 0.5 * (1 + cos(k * pi / max_iter)) * gamma``.
            flag: ``"forward"`` computes the update without recording an
                autograd graph (as used inside the fixed-point solver).
                ``"backward"`` keeps the caller's grad mode so gradients can
                flow through ``dnn``.

        Returns:
            ``field - gamma_k * (delta_g + tau * dnn(field))``.

        Raises:
            ValueError: if ``flag`` is neither ``"forward"`` nor ``"backward"``.
        """
        if flag not in ("forward", "backward"):
            raise ValueError(f"flag must be 'forward' or 'backward', got {flag!r}")

        gamma_recent = 0.5 * (1 + np.cos(interation * np.pi / self.max_iter)) * self.gamma

        # get_grad differentiates a loss, so autograd must be on even when the
        # caller is inside torch.no_grad(). The context manager restores the
        # caller's grad mode afterwards; a bare set_grad_enabled(False) would
        # leave autograd disabled globally.
        with torch.enable_grad():
            delta_g = self.get_grad(field, moving, fixed)

        # "forward" never records a graph for the update; "backward" follows
        # the ambient grad mode.
        keep_graph = flag == "backward" and torch.is_grad_enabled()
        with torch.set_grad_enabled(keep_graph):
            xSubD = self.tau * self.dnn(field)
            xnext = field - gamma_recent * (delta_g.detach() + xSubD)

        return xnext

    
class ResizeTransform(nn.Module):
    """Resize a vector field by ``1 / vel_resize`` and rescale its vectors.

    Resampling changes the voxel spacing, so the vector values are multiplied
    by the same factor (presumably the field holds displacements in voxel
    units). When shrinking, the field is interpolated first and then scaled;
    when growing, it is scaled first and then interpolated.
    """

    def __init__(self, vel_resize, ndims):
        """
        Args:
            vel_resize: Reciprocal of the resize factor (``0.5`` doubles the size).
            ndims: Number of spatial dimensions; selects 'bilinear' (2),
                'trilinear' (3) or plain 'linear' (anything else).
        """
        super().__init__()
        self.factor = 1.0 / vel_resize
        self.mode = {2: 'bilinear', 3: 'trilinear'}.get(ndims, 'linear')

    def _resample(self, x):
        # Corner-aligned spatial interpolation by `self.factor`.
        return F.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)

    def forward(self, x):
        """Return ``x`` resized and rescaled; ``x`` itself when the factor is 1."""
        scale = self.factor
        if scale < 1:
            return scale * self._resample(x)
        if scale > 1:
            return self._resample(scale * x)
        return x
    
def anderson_solver(f, x0, m=5, lam=1e-4, max_iter=50, tol=1e-4, beta=1.0):
    """Anderson acceleration for the fixed-point iteration ``x = f(x, k)``.

    Args:
        f: Callable ``f(x, k)`` returning a tensor with the shape of ``x``;
            ``k`` is the 0-based index of the evaluation.
        x0: Initial iterate. Dimension 0 is the batch; all remaining
            dimensions (any number of them) are flattened internally.
        m: History window size (must be >= 2).
        lam: Regularisation added to the diagonal of the residual Gram matrix.
        max_iter: Exclusive upper bound on the evaluation index ``k``.
        tol: Stop once the relative residual ``||f(x) - x|| / ||f(x)||``
            (norms taken over the whole batch) drops below this value.
        beta: Mixing parameter; 1.0 means undamped Anderson.

    Returns:
        ``(x, res)``: the latest accelerated iterate shaped like ``x0``, and
        the list of relative residuals, one per accelerated step (empty when
        ``max_iter <= 2``, in which case ``x`` is ``f(x0, 0)``).
    """
    bsz = x0.shape[0]
    dim = x0[0].numel()

    def _flat(t):
        # reshape (not view) so non-contiguous inputs/outputs are accepted.
        return t.reshape(bsz, -1)

    # Ring buffers of past iterates (X) and their images under f (FX).
    X = torch.zeros(bsz, m, dim, dtype=x0.dtype, device=x0.device)
    FX = torch.zeros(bsz, m, dim, dtype=x0.dtype, device=x0.device)

    X[:, 0], FX[:, 0] = _flat(x0), _flat(f(x0, 0))
    X[:, 1], FX[:, 1] = FX[:, 0], _flat(f(FX[:, 0].reshape_as(x0), 1))

    # Bordered system [[0, 1^T], [1, G G^T + lam I]] [mu; alpha] = [1; 0]:
    # least-squares mixing weights alpha constrained to sum to 1.
    gram = torch.zeros(bsz, m + 1, m + 1, dtype=x0.dtype, device=x0.device)
    gram[:, 0, 1:] = gram[:, 1:, 0] = 1
    rhs = torch.zeros(bsz, m + 1, 1, dtype=x0.dtype, device=x0.device)
    rhs[:, 0] = 1

    res = []
    k = 1  # index of the latest iterate if the loop below never runs

    for k in range(2, max_iter):
        n = min(k, m)
        G = FX[:, :n] - X[:, :n]
        eye = torch.eye(n, dtype=x0.dtype, device=x0.device)[None]
        gram[:, 1:n + 1, 1:n + 1] = torch.bmm(G, G.transpose(1, 2)) + lam * eye

        alpha = torch.linalg.solve(gram[:, :n + 1, :n + 1], rhs[:, :n + 1])[:, 1:n + 1, 0]

        slot = k % m
        X[:, slot] = (beta * (alpha[:, None] @ FX[:, :n])[:, 0]
                      + (1 - beta) * (alpha[:, None] @ X[:, :n])[:, 0])
        FX[:, slot] = _flat(f(X[:, slot].reshape_as(x0), k))
        res.append((FX[:, slot] - X[:, slot]).norm().item() / (1e-5 + FX[:, slot].norm().item()))

        if res[-1] < tol:
            break

    return X[:, k % m].reshape_as(x0), res

class DEQ(nn.Module):
    """Deep-equilibrium wrapper around a single update operator ``ForwardI``.

    The fixed point of ``ForwardI`` is found without gradients by
    ``anderson_solver``. One further application of ``ForwardI`` at that point
    is then made outside ``torch.no_grad()``, so only that last step can carry
    gradients (the original code labels this "JFB").
    """

    def __init__(self, ForwardI):
        """
        Args:
            ForwardI: Callable ``(field, moving, fixed, iteration, flag)``
                returning the next field estimate.
        """
        super().__init__()
        self.ForwardI = ForwardI

    def forward(self, field, moving, fixed):
        """Solve for the equilibrium field, then apply one final update.

        Returns:
            ``(field_hat, n_steps, last_residual)``: the final update's output,
            the number of accelerated solver steps, and the solver's last
            relative residual.
        """
        def _fixed_point_step(z, k):
            return self.ForwardI(z, moving, fixed, k, "forward")

        with torch.no_grad():
            field_fixed, residuals = anderson_solver(
                _fixed_point_step,
                field,
                max_iter=500,
                tol=1e-3,
            )
            n_steps = len(residuals)
            last_residual = residuals[-1]

        # Training keeps the caller's grad mode ("backward"); evaluation
        # reuses the solver's gradient-free update ("forward").
        mode = "backward" if self.training else "forward"
        field_hat = self.ForwardI(field_fixed, moving, fixed, n_steps, mode)

        return field_hat, n_steps, last_residual



